import os

import gym
import numpy as np
import torch
from torchvision import datasets
import torchvision.transforms as transforms
from gym.spaces.box import Box
from gym.wrappers import Monitor
import easydict

from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecEnvWrapper
from stable_baselines3.common.vec_env.vec_video_recorder import VecVideoRecorder
from stable_baselines3.common.vec_env.vec_normalize import VecNormalize as VecNormalize_
from wrappers import TimeLimit, Monitor, FlattenObservation

from mod_envs.pp_wrapper import *
from mod_envs.tj_wrapper import *
#from mod_envs.ic3_envs.env_init import *
from mod_envs.cifar_game_env import *
import MSMTC.DigitalPose2D as poseEnv
from MSMTC.environment import env_wrapper as MSMTCWrapper
from utils import Map
from mod_envs.mg_env_fg_utils import make_environment as make_marlgrid_fg_environment
from mod_envs.mg_env_fg_utils import  get_env_name as  get_mg_fg_env_name
from mod_envs.mg_env_rbd_utils import make_environment as make_marlgrid_rbd_environment
from mod_envs.mg_env_rbd_utils import  get_env_name as  get_mg_rbd_env_name

# from vmas import make_env as vmas_make_env

class MADummyVecEnv(DummyVecEnv):
    """``DummyVecEnv`` variant whose reward buffer has one column per agent."""

    def __init__(self, env_fns):
        """Build the wrapped envs, then resize the reward buffer.

        The number of agents is taken from the length of the observation
        space set up by the parent constructor.
        """
        super().__init__(env_fns)
        n_agents = len(self.observation_space)
        # The parent class keeps a single reward per env; here every agent
        # gets its own reward, hence a (num_envs, n_agents) buffer.
        reward_shape = (self.num_envs, n_agents)
        self.buf_rews = np.zeros(reward_shape, dtype=np.float32)

def make_env(env_id, env_configs, seed, rank, time_limit, wrappers, env_properties, monitor_dir):
    """Return a zero-argument factory that builds one wrapped environment.

    Construction is deferred to the returned thunk so that it can be handed
    to a vectorized-env constructor and executed later.

    Args:
        env_id: Environment identifier. Substrings of it select the
            construction branch ('rware', 'ic3net', 'FullCoopPredatorPreyWrapper',
            'PredatorPreyWrapper', 'TrafficJunction', 'cifar', 'MSMTC',
            'MarlGridFindGoal', 'MarlGridRedBlueDoors'); anything else falls
            back to ``gym.make(env_id)``.
        env_configs: Mapping of environment-specific settings; the keys read
            depend on the selected branch.
        seed: Base random seed.
        rank: Index of this environment; most branches seed with
            ``seed + rank``.
        time_limit: If truthy, the env is wrapped in ``TimeLimit``.
        wrappers: Sequence of callables applied to the env in order. For
            'cifar' and 'MarlGrid' ids the last one is skipped.
        env_properties: Optional mapping with 'sensor_range' and
            'request_queue_size', used only by the 'rware' branch.
        monitor_dir: If truthy, the env is wrapped in ``Monitor`` writing to
            this directory.

    Returns:
        A callable taking no arguments and returning the wrapped environment.

    Raises:
        ValueError: (from the thunk) if ``env_id`` contains 'ic3net' but names
            no supported ic3net scenario.
    """
    def _thunk():
        if 'rware' in env_id and env_properties is not None:
            env = gym.make(env_id, sensor_range = env_properties['sensor_range'], request_queue_size = env_properties['request_queue_size'])
            env.seed(seed + rank)
        elif 'ic3net' in env_id:
            if 'predator_prey' not in env_id:
                # Previously `env` stayed unassigned here and the failure
                # surfaced later as an opaque UnboundLocalError.
                raise ValueError(
                    "Unsupported ic3net environment '{}': only 'predator_prey' is handled".format(env_id)
                )
            # NOTE(review): `init` presumably comes from
            # mod_envs.ic3_envs.env_init, whose import is commented out at
            # the top of this file -- confirm it is still in scope.
            env = init('predator_prey', Map(env_configs))
        elif 'FullCoopPredatorPreyWrapper' in env_id:
            env = FullCoopPredatorPreyWrapper(centralized = False, grid_shape = (env_configs['grid_size'], env_configs['grid_size']),
                  n_agents = env_configs['num_agents'],
                  n_preys = env_configs['num_preys'],
                  step_cost = env_configs['step_cost'],
                  max_steps = env_configs['time_limit'],
                  prey_capture_reward = env_configs['prey_capture_reward'],
                  penalty = env_configs['penalty'],
                  other_agent_visible= env_configs['other_agent_visible'],
                  prey_move_probs = tuple(env_configs['prey_move_probs'])
            )
            env.seed(seed + rank)
            # The action space is iterated per agent here, so each sub-space
            # is seeded individually.
            for a_space in env.action_space:
                a_space.seed(seed + rank)
        elif 'PredatorPreyWrapper' in env_id:
            env = PredatorPreyWrapper(centralized = False, grid_shape = (env_configs['grid_size'], env_configs['grid_size']),
                  n_agents = env_configs['num_agents'],
                  n_preys = env_configs['num_preys'],
                  step_cost = env_configs['step_cost'],
                  max_steps = env_configs['time_limit'],
                  prey_capture_reward = env_configs['prey_capture_reward'],
                  penalty = env_configs['penalty'],
                  other_agent_visible= env_configs['other_agent_visible'],
            )
            env.seed(seed + rank)
            env.action_space.seed(seed + rank)
        elif 'TrafficJunction' in env_id:
            # Not using curriculum learning, so rate min equal to rate max
            env = TrafficJunctionWrapper(
                centralized = False,
                dim = env_configs['dim'],
                vision = env_configs['vision'],
                add_rate_min = env_configs['add_rate_max'],
                add_rate_max = env_configs['add_rate_max'],
                curr_start = 0,
                curr_end = 0,
                difficulty = env_configs['difficulty'],
                n_agents = env_configs['num_agents'],
                max_steps = env_configs['time_limit']
            )
            env.seed(seed + rank)
        elif 'cifar' in env_id:
            transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
            data = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
            # Note: the config carries the base `seed`, while the env itself
            # is seeded with `seed + rank` below.
            cifar_env_config = {'share_reward': env_configs['share_reward'], 'discrete_comm': env_configs['discrete_comm'], 'comm_size': env_configs['comm_size'], 'seed': seed, 'max_steps' : env_configs['time_limit'], 'naive' : env_configs['naive']}
            env = create_game_env(cifar_env_config, data)
            env.seed(seed + rank)
            env.action_space.seed(seed + rank)
        elif 'MSMTC' in env_id:
            env = poseEnv.gym.make(env_id, False)
            # adjust env steps according to args
            env.max_steps = env_configs['time_limit']
            env.seed(seed + rank)
        elif 'MarlGridFindGoal' in env_id:
            new_env_configs = easydict.EasyDict(env_configs)
            new_env_configs['env_name'] = get_mg_fg_env_name(new_env_configs)
            env = make_marlgrid_fg_environment(new_env_configs)
            env.seed(seed + rank)
        elif 'MarlGridRedBlueDoors' in env_id:
            new_env_configs = easydict.EasyDict(env_configs)
            new_env_configs['env_name'] = get_mg_rbd_env_name(new_env_configs)
            env = make_marlgrid_rbd_environment(new_env_configs)
            env.seed(seed + rank)
        # elif('VMAS' in env_id):
        #     scenario_name = env_id.split('_')[1]
        #     env = vmas_make_env(scenario_name = scenario_name, num_envs = env_configs['num_envs'], device = env_configs['device'], continuous_actions = env_configs['continuous_actions'], n_agents = env_configs['n_agents'],)
        #     print(len(env))
        #     exit()
        else:
            env = gym.make(env_id)
            env.seed(seed + rank)
            env.action_space.seed(seed + rank)

        if time_limit:
            if 'cifar' in env_id:
                # One extra step is needed to trigger label guessing at the
                # end of the episode; otherwise the episode would end
                # prematurely.
                env = TimeLimit(env, time_limit + 1)
            else:
                env = TimeLimit(env, time_limit)

        # For cifar / MarlGrid the last wrapper is skipped (the flatten
        # observation wrapper, per the original note -- this assumes callers
        # pass it last).
        skip_last_wrapper = 'cifar' in env_id or 'MarlGrid' in env_id
        active_wrappers = wrappers[: -1] if skip_last_wrapper else wrappers
        for wrapper in active_wrappers:
            env = wrapper(env)

        if monitor_dir:
            # The callable returns 1 only for episode 0.
            env = Monitor(env, monitor_dir, lambda ep: int(ep==0), force=True, uid=str(rank))

        return env

    return _thunk


def make_vec_envs(
    env_name, env_configs, seed, dummy_vecenv, parallel, time_limit, wrappers, device, env_properties = None, monitor_dir = None
):
    """Build ``parallel`` environments and wrap them in a ``VecPyTorch``.

    Args:
        env_name: Environment identifier forwarded to ``make_env``.
        env_configs: Environment settings forwarded to ``make_env``.
        seed: Base seed; environment ``i`` is created with rank ``i``.
        dummy_vecenv: If truthy, run all environments in the current process.
        parallel: Number of environments to create.
        time_limit: Forwarded to ``make_env``.
        wrappers: Forwarded to ``make_env``.
        device: Torch device handed to ``VecPyTorch``.
        env_properties: Optional extra settings forwarded to ``make_env``.
        monitor_dir: Optional monitor directory forwarded to ``make_env``;
            when set, the in-process vectorized env is used.

    Returns:
        A ``VecPyTorch`` wrapping the vectorized environments.
    """
    envs = [
        make_env(env_name, env_configs, seed, i, time_limit, wrappers, env_properties, monitor_dir)
        for i in range(parallel)
    ]

    if dummy_vecenv or len(envs) == 1 or monitor_dir:
        envs = MADummyVecEnv(envs)
    else:
        try:
            envs = SubprocVecEnv(envs, start_method="fork")
        except Exception:
            # Best-effort fallback when "fork" cannot be used. Catching
            # Exception (rather than a bare except) keeps KeyboardInterrupt
            # and SystemExit from being swallowed.
            envs = SubprocVecEnv(envs, start_method="spawn")

    envs = VecPyTorch(envs, device, env_name)
    return envs


class VecPyTorch(VecEnvWrapper):
    """Vectorized-env wrapper that exchanges torch tensors with the caller.

    Observations, rewards and done flags produced by the wrapped ``venv`` are
    converted to tensors on ``device``; incoming actions are converted back to
    numpy before being forwarded. Environments whose name contains 'cifar' or
    'MarlGrid' get dedicated observation handling.
    """

    def __init__(self, venv, device, env_name):
        """Store the target device and the env name used for dispatching.

        Args:
            venv: Vectorized environment to wrap.
            device: Torch device that produced tensors are moved to.
            env_name: Environment name; checked for the substrings 'cifar'
                and 'MarlGrid'.
        """
        super(VecPyTorch, self).__init__(venv)
        self.device = device
        self.env_name = env_name
        # TODO: Fix data types

    def process_obs_for_cifar(self, obs):
        """Collate per-process dict observations into a dict of tensors.

        ``obs`` is indexed as ``obs[process][agent][key]``. Only agents 0 and
        1 are read. For every key of agent 0's dict the two agents' values are
        stacked, and the results are stacked again over processes.

        Returns:
            Dict mapping each key to a tensor whose first two dimensions are
            (process, agent).
        """
        collected = {}
        for proc_obs in obs:
            for key in proc_obs[0].keys():
                agent_pair = torch.stack((
                    torch.tensor(proc_obs[0][key]).to(self.device),
                    torch.tensor(proc_obs[1][key]).to(self.device),
                ))
                collected.setdefault(key, []).append(agent_pair)
        return {key: torch.stack(per_proc) for key, per_proc in collected.items()}

    def process_obs_for_marlgrid(self, obs):
        """Regroup two-part observations from per-process to per-agent.

        ``obs`` is indexed as ``obs[process][agent][part]`` with parts 0 and 1
        (the original code names them "img" and "df").

        Returns:
            A list with one ``(part0, part1)`` tuple of tensors per agent;
            each tensor stacks that part over all processes.
        """
        new_obs = []
        for a_idx in range(len(obs[0])):
            img_batch = np.stack([proc_obs[a_idx][0] for proc_obs in obs])
            df_batch = np.stack([proc_obs[a_idx][1] for proc_obs in obs])
            new_obs.append((
                torch.from_numpy(img_batch).to(self.device),
                torch.from_numpy(df_batch).to(self.device),
            ))
        return new_obs

    def reset(self):
        """Reset the wrapped envs and return the observations as tensors."""
        # num_agents x num_processes x obs_size
        obs = self.venv.reset()
        if 'cifar' in self.env_name:
            return self.process_obs_for_cifar(obs)
        if 'MarlGrid' in self.env_name:
            return self.process_obs_for_marlgrid(obs)
        # NOTE(review): unlike step_wait, no .float() cast is applied here,
        # so the dtype follows the numpy observation -- confirm this is
        # intended.
        return [torch.from_numpy(o).to(self.device) for o in obs]

    def step_async(self, actions):
        """Convert tensor actions to numpy, one entry per process, and forward them.

        For 'cifar' envs ``actions`` is a pair ``(env_actions, comm_actions)``
        of tensor sequences; otherwise it is a sequence of per-agent tensors.
        """
        if 'cifar' in self.env_name:
            # Need to repackage actions here (number of processes, *action shape)
            env_actions, comm_actions = actions
            # .detach().cpu() is required before .numpy() when the tensors
            # live on a non-CPU device or carry gradients.
            env_actions = torch.transpose(torch.stack(env_actions), 0, 1).detach().cpu().numpy()
            comm_actions = torch.transpose(torch.stack(comm_actions), 0, 1).clone().detach().cpu().numpy()
            combined_actions = [
                (env_actions[p_idx], comm_actions[p_idx])
                for p_idx in range(env_actions.shape[0])
            ]
            return self.venv.step_async(combined_actions)

        # With a single process squeeze() can yield a 0-d array, which
        # zip(*...) cannot iterate; atleast_1d keeps the process axis.
        actions = [np.atleast_1d(a.squeeze().detach().cpu().numpy()) for a in actions]
        # Regroup from per-agent sequences to one tuple per process.
        actions = list(zip(*actions))
        return self.venv.step_async(actions)

    def step_wait(self):
        """Wait for the step to finish and return ``(obs, rew, done, info)``.

        Rewards and done flags are returned as float tensors on the
        configured device; ``info`` is passed through untouched.
        """
        # We need (num_agent, num_process, num_features), original pp environment gives (num_processes, num_agent, num_features)
        obs, rew, done, info = self.venv.step_wait()
        if 'cifar' in self.env_name:
            obs = self.process_obs_for_cifar(obs)
        elif 'MarlGrid' in self.env_name:
            obs = self.process_obs_for_marlgrid(obs)
        else:
            obs = [torch.from_numpy(o).float().to(self.device) for o in obs]
        return (
            obs,
            torch.from_numpy(rew).float().to(self.device),
            torch.from_numpy(done).float().to(self.device),
            info,
        )
